###############################################################################-
# Author: Pietryka
# Contact: matthew.pietryka@gmail.com
# Purpose: create Figure 4: Comparing peers and parents 
# Notes:
###############################################################################-


#  1. Load Packages  =====================


library(dplyr)
library(tidyr)
library(purrr)
library(readr)
library(estimatr)
library(janitor)
library(ggplot2)
library(forcats)
library(stringr)
library(glue)

#  2. Load Data =====================

stacked_df <- read_rds("data-files/stacked_df.rds")
source("Plot Preferences.r")


#  3. Functions for plots =====================

# Drop the "0" at the first place where four consecutive digits are followed
# by a zero, e.g. "turnout_20160_f" -> "turnout_2016_f". The pattern is not
# anchored to the end of the string; inputs without such a sequence come back
# unchanged.
remove_trailing_zero <- function(x) {
  str_replace(string = x, pattern = "(\\d{4})0", replacement = "\\1")
}

# Fit a regression with robust standard errors via lm_robust().
#   a_formula: model formula, e.g. turnout_2016_f ~ turnout_2016_rm
#   data:      data frame holding the model variables. Defaults to the
#              script-level `stacked_df`, so existing one-argument calls
#              (e.g. map(the_form, ols_robust)) behave exactly as before.
# Returns the fitted model object produced by lm_robust().
ols_robust <- function(a_formula, data = stacked_df) {
  lm_robust(a_formula, data = data)
}

# Predicted outcome, with confidence bounds, at predictor values 0 and 1.
#
#   mod: a fitted one-predictor model (as returned by ols_robust()) that has
#        `terms` and `p.value` elements and supports
#        predict(..., interval = "confidence").
#
# Returns a two-row tibble with columns:
#   `(Intercept)`  constant 1 (carried over from the prediction grid)
#   x_value        predictor value used for the prediction (0, then 1)
#   yhat, lb, ub   fitted value and the confidence bounds from predict()
#   intercept      estimated intercept (repeated on both rows)
#   beta           estimated coefficient on the predictor (repeated)
#   p_value        p-value of that coefficient (repeated)
#
# Stops with an error when the model does not have exactly one predictor
# term, because the coefficient and p-value are looked up by a single name.
predict_y <- function(mod) {
  column_name <- attr(mod$terms, "term.labels")

  # Fail early and readably instead of with an opaque `[[` subscript error.
  if (length(column_name) != 1L) {
    stop(
      "predict_y() expects a model with exactly one predictor term; got ",
      length(column_name), ".",
      call. = FALSE
    )
  }

  intercept_est <- coef(mod)[["(Intercept)"]]
  slope_est <- coef(mod)[[column_name]]
  slope_p <- mod$p.value[[column_name]]

  # Prediction grid: the predictor at 0 and at 1. The grid is built with a
  # placeholder name and then renamed to the model's actual predictor.
  # The `(Intercept)` column is kept so the returned columns stay unchanged.
  new_df <- tibble(
    `(Intercept)` = c(1, 1),
    x_value = c(0, 1)
  ) |>
    rename(!!column_name := x_value)

  # as.data.frame() flattens the predict() result into fit.fit / fit.lwr /
  # fit.upr, which are renamed to shorter plotting names.
  predict_df <- predict(mod, newdata = new_df, interval = "confidence") |>
    as.data.frame() |>
    as_tibble() |>
    rename(yhat = fit.fit, lb = fit.lwr, ub = fit.upr)

  new_df |>
    rename(x_value := !!column_name) |>
    bind_cols(predict_df) |>
    mutate(
      intercept = intercept_est,
      beta = slope_est,
      p_value = slope_p
    )
}




# 4. Generate combinations of variables =============

# Outcome variables (left-hand side of each regression).
y_vars <- c("turnout_2016_f", "turnout_2018_f", "turnout_2020_f")

# Predictor variables (right-hand side of each regression).
x_vars <- c("turnout_2016_rm", "parents_turnout_mean0_f")

# One row per outcome/predictor pair. The pairs are formed on positions and
# the names are looked up afterwards; the *_year columns hold the first
# number parsed from each cleaned variable name.
combos_df <- crossing(y = seq_along(y_vars), x = seq_along(x_vars)) |>
  mutate(
    y_name = y_vars[y],
    x_name = x_vars[x],
    y_year = parse_number(remove_trailing_zero(y_name)),
    x_year = parse_number(remove_trailing_zero(x_name))
  )



# 5. Calculate turnout by each X variable in each year ==============
# For every outcome/predictor pair: build the formula `y ~ x`, fit it with
# ols_robust(), and store the two-row prediction table from predict_y() as a
# list-column. The positional index columns are dropped and the name columns
# become plain `x` / `y`.
ols_long_df <- combos_df |>
  mutate(
    the_form = map2(
      x_name, y_name,
      \(rhs, lhs) as.formula(paste(lhs, rhs, sep = "~"))
    ),
    ols_mod = map(the_form, ols_robust),
    predict_df = map(ols_mod, predict_y)
  ) |>
  select(-x, -y) |>
  clean_names() |>
  rename(x = x_name, y = y_name)




# 6. clean the names, years, etc ==================

# Recode raw variable names into "<type>_<year>_<code>" strings that can later
# be split on "_". Each step changes only the first match in a string:
#   "turnout_2016_f"          -> "student_2016_a"
#   "turnout_2016_rm"         -> "student_2016_b"
#   "parents_turnout_mean0_f" -> "parents_2008-14_a"
# x: character vector of variable names. Returns a character vector of the
# same length.
clean_var_names <- function(x) {
  out <- remove_trailing_zero(x)
  out <- str_replace(out, "_f", "_a")
  out <- str_replace(out, "_rm", "_b")
  out <- str_replace(out, "mean0", "2008-14")
  out <- str_remove(out, "_turnout")
  # Names that still start with "turnout_" at this point become "student_".
  str_replace(out, "^turnout_", "student_")
}


# Expand the nested prediction tables to one row per prediction, recode the
# x / y variable names, and split each into type / year / code pieces.
# `lab` joins the six pieces (type_y through roommate_x) with "_", and the
# pieces themselves are kept.
ols_clean_df <- ols_long_df |>
  unnest(cols = predict_df) |>
  mutate(across(.cols = c(x, y), .fns = clean_var_names)) |>
  separate(
    col = x,
    into = c("type_x", "year_x", "roommate_x"),
    sep = "_"
  ) |>
  separate(
    col = y,
    into = c("type_y", "year_y", "roommate_y"),
    sep = "_"
  ) |>
  unite(col = "lab", type_y:roommate_x, remove = FALSE)








# 7. Make the plot =================================




# Add plotting helpers:
#   type            display label for the predictor group
#   facet_lab       outcome year, used as the facet label
#   x_value_unique  axis position for each (group, x_value) pair; student
#                   rows sit at 5/4 and parent rows at 2/1, so position 3 is
#                   left empty between the two groups
plot_df <- ols_clean_df |>
  mutate(
    type = ifelse(type_x == "parents", "Parents, 2008-14", "Roommate in 2016"),
    facet_lab = year_y,
    x_value_unique = case_when(
      type_x == "student" & x_value == 0 ~ 5L,
      type_x == "student" & x_value == 1 ~ 4L,
      type_x == "parents" & x_value == 0 ~ 2L,
      type_x == "parents" & x_value == 1 ~ 1L
    )
  )




# Point-and-interval plot of predicted turnout, one facet per outcome year.
# NOTE(review): format_plot(), se_plot_x_labs and se_plot_shapes are not
# defined in this file; presumably they come from "Plot Preferences.r".
se_plot <- ggplot(
  data = plot_df,
  mapping = aes(
    x = x_value_unique,
    y = yhat,
    shape = factor(x_value_unique),
    color = type_x,
    fill = type_x
  )
) +
  facet_wrap(~facet_lab, ncol = 1) +
  # Intervals first, points second, so the white-filled markers are drawn
  # over the interval lines.
  geom_linerange(aes(ymin = lb, ymax = ub), linewidth = 0.9) +
  geom_point(size = 3, fill = "white", stroke = 1.5) +
  coord_flip() +
  format_plot(base_size = 14) +
  xlab(NULL) +
  ylab("Percent voting") +
  ggtitle("Turnout by year", "Among students whose ...") +
  scale_colour_grey(start = 0.5, end = 0) +
  scale_y_continuous(
    labels = scales::label_percent(accuracy = NULL, scale = 100),
    breaks = seq(0.50, 0.80, by = 0.10)
  ) +
  # Axis breaks are the integer positions stored as the names of the label
  # vector.
  scale_x_continuous(
    breaks = as.integer(names(se_plot_x_labs)),
    labels = se_plot_x_labs,
    expand = expansion(mult = 0.1, add = 0)
  ) +
  scale_shape_manual(values = se_plot_shapes) +
  theme(
    legend.position = "none",
    panel.spacing = unit(1, "lines"),
    strip.text = element_text(
      family = "Arial Narrow",
      face = "bold",
      size = rel(1.1),
      hjust = 0.5
    ),
    plot.title = element_text(
      family = "Arial Narrow",
      margin = margin(t = 10, r = 0, b = 0, l = 0)
    ),
    plot.subtitle = element_text(
      family = "Arial Narrow",
      margin = margin(t = 0, r = 0, b = 0, l = 0)
    ),
    plot.title.position = "plot"
  )



# Preview the figure in a fresh 4 x 5 inch device (the same size used when
# saving). dev.new() is used instead of windows() because windows() exists
# only in R for Windows, so the script would otherwise stop here on
# macOS/Linux. noRStudioGD = TRUE skips the RStudio plot pane, which does not
# accept width/height. print() draws the plot even when this file is run
# through source(), where a bare `se_plot` would not be auto-printed.
graphics.off()
dev.new(width = 4, height = 5, noRStudioGD = TRUE)

print(se_plot)




# 8. Save =============================================

# Create the output folder if it is missing so ggsave() does not fail (or
# prompt) on a fresh checkout; this is a no-op when the folder already exists.
dir.create("Results", showWarnings = FALSE, recursive = TRUE)

# Write the figure as PNG and TIFF at 4 x 5 inches, with showtext handling
# the text rendering while the files are written.
# NOTE(review): the showtext_* functions are not attached in this file;
# presumably "Plot Preferences.r" loads showtext - confirm.
# NOTE(review): showtext is set to 600 dpi, but ggsave() is called without
# `dpi` and so uses its own default - confirm the mismatch is intended.
showtext_begin()
showtext_opts(dpi = 600)
ggsave(filename = "Results/se_plot.png", plot = se_plot, width = 4, height = 5)
ggsave(filename = "Results/se_plot.tiff", plot = se_plot, width = 4, height = 5)
showtext_end()
